import taichi as ti
import numpy as np
import logging
from time import time
from .common import Vec3, Mat3x3, Inf, Ray3
from ..shape import Shape
from ..helper.log import Logger
from dataclasses import dataclass

# vispy is imported lazily (in readModel) for formats this module does not
# parse itself; raise its logger threshold so only errors are reported.
logging.getLogger("vispy").setLevel(logging.ERROR)

# Public API of this module.
__all__ = ["readModel", "loadModel", "showModel"]


def genNormal(v: np.ndarray, f: np.ndarray) -> np.ndarray:
    """Compute per-vertex unit normals for a triangle mesh.

    Every triangle adds its un-normalised face normal ``(b - a) x (c - a)``
    (so larger triangles weigh more) to each of its three corner vertices;
    the accumulated vectors are then scaled to unit length.

    Args:
        v: vertex positions, one row per vertex (x, y, z columns).
        f: integer vertex indices, one row of three corners per triangle.

    Returns:
        Array shaped like ``v`` holding one unit normal per vertex. Vertices
        not referenced by any face get a zero vector (the ``1e-9`` term keeps
        the division finite).
    """
    corners = v[f]
    face_normals = np.cross(corners[:, 1] - corners[:, 0], corners[:, 2] - corners[:, 0])
    flat_idx = f.ravel()
    accum = np.empty_like(v)
    for axis in range(3):
        # ravel() lists the three corners of each face consecutively, so each
        # face's normal component is repeated three times to line up with them.
        accum[:, axis] = np.bincount(
            flat_idx,
            weights=np.repeat(face_normals[:, axis], 3),
            minlength=len(v),
        )
    lengths = np.sum(accum**2, axis=1) ** 0.5 + 1e-9
    return accum / lengths[:, None]


def readOBJ(filename: str):
    """Read a Wavefront OBJ file and return ``(vertices, faces)``.

    Adapted from ``vispy.io.wavefront``. Only ``v`` and ``f`` records are
    used; texture/normal references inside face entries (``i/t``, ``i//n``,
    ``i/t/n``) are ignored.

    * Vertex records may carry extra values (an optional ``w`` or a vertex
      colour); only the first three coordinates are kept.
    * Negative (relative) face indices are resolved against the vertices
      read so far.
    * Polygons with more than three corners are fan-triangulated.

    Args:
        filename: path to a file whose name ends in ``.obj`` (any case).

    Returns:
        ``vertices``: ``(V, 3)`` float32 positions.
        ``faces``: ``(F, 3)`` int32 zero-based vertex indices.

    Raises:
        AssertionError: if *filename* does not end in ``.obj``.
        ValueError: on a malformed record (e.g. a face index of 0).
    """
    assert filename.lower().endswith('.obj')
    v, f = [], []
    with open(filename) as obj:
        for line in obj:
            # split() (not split(' ')) tolerates repeated spaces and tabs.
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == 'v':
                v.append(tokens[1:4])
            elif tokens[0] == 'f':
                corners = []
                for tok in tokens[1:]:
                    # Only the vertex index (before the first '/') is used.
                    i = int(tok.split('/', 1)[0])
                    if i == 0:
                        raise ValueError(f"invalid OBJ vertex index 0 in {filename!r}")
                    # OBJ indices are 1-based; negative ones count back from
                    # the most recently defined vertex.
                    corners.append(i - 1 if i > 0 else len(v) + i)
                # Fan triangulation: (c0, ck, ck+1) for every interior k.
                f.extend(
                    (corners[0], corners[k], corners[k + 1])
                    for k in range(1, len(corners) - 1)
                )
    # reshape keeps the (N, 3) shape even when the file has no records.
    vertices = np.array(v, dtype=np.float32).reshape(len(v), 3)
    faces = np.array(f, dtype=np.int32).reshape(len(f), 3)
    return vertices, faces


def _vtkTokens(lines, count: int) -> list:
    """Collect exactly *count* whitespace-separated tokens from the line iterator *lines*.

    Values in a legacy VTK data section are whitespace-separated and need not
    be one record per line, so sections are read as a token stream.
    """
    tokens = []
    while len(tokens) < count:
        line = next(lines, None)
        if line is None:
            raise ValueError(
                f"unexpected end of VTK file: expected {count} values, got {len(tokens)}"
            )
        tokens.extend(line.split())
    return tokens[:count]


def readVTK(filename: str, makeSurface: bool = True):
    """Read a tetrahedral mesh from a legacy ASCII ``.vtk`` file.

    Adapted from https://github.com/chunleili/tiReadVtkTet.

    Args:
        filename: path to the ``.vtk`` file.
        makeSurface: if true, return only the boundary triangles and the
            points they use; otherwise return all points and cells.

    Returns:
        ``(points, cells)`` as float32 / int32 arrays. With ``makeSurface``
        the cells are ``(F, 3)`` triangles indexing the returned (compacted)
        points; without it they are the raw cell index lists (``(T, 4)`` for
        a purely tetrahedral mesh).

    Raises:
        ValueError: if a POINTS/CELLS section is truncated.
    """
    pos, tet = [], []
    with open(filename) as fh:
        lines = iter(fh)
        for line in lines:
            head = line.split()
            if not head:
                continue
            if head[0] == "POINTS":
                # "POINTS n dtype" is followed by 3 * n coordinates.
                pos = _vtkTokens(lines, 3 * int(head[1]))
            elif head[0] == "CELLS":
                # "CELLS n size": `size` integers, each cell written as "k i1 ... ik".
                n, size = int(head[1]), int(head[2])
                flat = _vtkTokens(lines, size)
                p = 0
                for _ in range(n):
                    k = int(flat[p])
                    tet.append(flat[p + 1 : p + 1 + k])
                    p += k + 1
            elif head[0] == "CELL_TYPES":
                break
    pos = np.array(pos, dtype=np.float32).reshape(-1, 3)
    tet = np.array(tet, dtype=np.int32)
    if makeSurface:
        # The four faces of every tetrahedron (winding presumably chosen to face
        # outward for the input's tet orientation -- verify against the data).
        surf = np.concatenate(
            [tet[:, [0, 2, 1]], tet[:, [0, 1, 3]], tet[:, [0, 3, 2]], tet[:, [1, 2, 3]]]
        )
        # A face shared by two tets occurs twice (in some vertex order);
        # boundary faces occur exactly once.
        _, index, count = np.unique(
            np.sort(surf, axis=1), axis=0, return_index=True, return_counts=True
        )
        surf = surf[index[count == 1]]
        # Keep only points used by the surface and re-index faces into them.
        uni, inv = np.unique(surf, return_inverse=True)
        return pos[uni], inv.astype(np.int32).reshape(-1, 3)
    return pos, tet


def readModel(filename: str):
    """Read a triangle mesh from *filename*.

    ``.vtk`` (boundary surface of a tetrahedral mesh) and ``.obj`` files are
    parsed by this module, matching the suffix case-insensitively; any other
    format is delegated to ``vispy.io.read_mesh``.

    Returns:
        ``(vertices, faces, normals)``. The vertices are translated so that
        their mean is at the origin. Normals are per-vertex, because callers
        index them with the face array, as in ``normals[faces]``.
    """
    suffix = filename.lower()
    if suffix.endswith(".vtk"):
        _vertex, _face = readVTK(filename)
        _norm = genNormal(_vertex, _face)
    elif suffix.endswith(".obj"):
        _vertex, _face = readOBJ(filename)
        _norm = genNormal(_vertex, _face)
    else:
        from vispy.io import read_mesh

        _vertex, _face, _norm, *_ = read_mesh(filename)
        # NOTE(review): read_mesh may return no normals, or presumably per-face
        # normals for some formats (e.g. STL). Per-vertex normals are required,
        # so recompute them whenever the shape does not match the vertices.
        if _norm is None or np.shape(_norm) != np.shape(_vertex):
            _norm = genNormal(_vertex, _face)
    _vertex -= _vertex.mean(axis=0)
    return _vertex, _face, _norm


@ti.kernel
def SAHSplit(
    tri: ti.types.ndarray(Mat3x3),
    cost: ti.types.ndarray(float),
):
    """Surface Area Heuristic (SAH) split-cost evaluation.

    Accumulates into ``cost[k]`` the SAH cost of splitting the triangle list
    into ``tri[:k]`` (left) and ``tri[k:]`` (right). Each side contributes
    half the surface area of its AABB (``dx*dy + dx*dz + dy*dz``) times its
    triangle count. ``cost[0]`` holds only the right-hand term, which covers
    all triangles. Values are added with ``+=``, so the caller must pass a
    zero-initialised array. loadModel passes triangles sorted along the split axis.
    """
    # Running AABBs of the left prefix tri[:i] and of the right suffix tri[j:].
    lMin = rMin = Vec3(Inf, Inf, Inf)
    lMax = rMax = -Vec3(Inf, Inf, Inf)
    # Every iteration depends on the bounds grown by the previous one, so the
    # loop has to run sequentially.
    ti.loop_config(serialize=True)
    for i in range(tri.shape[0]):
        # Sweep the left box forward (i) and the right box backward (j) together.
        j = tri.shape[0] - 1 - i
        if i > 0:
            # Left term of split k = i: box of tri[:i] (before adding tri[i]) times i triangles.
            lDx = lMax - lMin
            cost[i] += (lDx.x * lDx.y + lDx.x * lDx.z + lDx.y * lDx.z) * i
        # Grow the boxes; column k of a triangle matrix holds coordinate k of
        # its three vertices (rows), so min/max over the column bound that axis.
        lMin = min(lMin, [tri[i][:, k].min() for k in range(3)])
        lMax = max(lMax, [tri[i][:, k].max() for k in range(3)])
        rMin = min(rMin, [tri[j][:, k].min() for k in range(3)])
        rMax = max(rMax, [tri[j][:, k].max() for k in range(3)])
        rDx = rMax - rMin
        # Right term of split k = j: box of tri[j:] times its n - j = i + 1 triangles.
        cost[j] += (rDx.x * rDx.y + rDx.x * rDx.z + rDx.y * rDx.z) * (i + 1)


@dataclass(slots=True)
class BVHNode:
    """One node of the bounding volume hierarchy built by ``loadModel``.

    Nodes live in a flat list in post-order: children come before their parent,
    and the root is last. Child and parent fields are indices into that list,
    and ``-1`` means "none".
    """

    index: int  # leaf: position of its triangle in the reordered triangle array; internal node: -1
    lBound: np.ndarray  # lower corner (per-axis minimum) of the node's AABB
    uBound: np.ndarray  # upper corner (per-axis maximum) of the node's AABB
    lChild: int  # list index of the left child, -1 for leaves
    rChild: int  # list index of the right child, -1 for leaves
    parent: int  # list index of the parent, -1 for the root


def _threadBVH(nodes) -> tuple[np.ndarray, np.ndarray]:
    """Compute stack-free traversal links for a post-ordered BVH node list.

    *nodes* must be ordered as ``loadModel``'s ``build`` emits them: children
    before parents, root last. Returns ``(missNext, hitNext)``.
    ``missNext[i]`` is the node to visit after node ``i``'s box is missed:
    its right sibling, or else its parent's ``missNext``. ``hitNext[i]`` is
    the node to visit after a hit: its first child, or ``missNext[i]`` for a
    leaf. ``-1`` ends the traversal.
    """
    n = len(nodes)
    missNext = np.empty((n,), dtype=int)
    missNext[-1] = -1
    # Parents have larger indices than their children, so walking downwards
    # guarantees missNext[parent] is already known.
    for i in range(n - 2, -1, -1):
        parent = nodes[i].parent
        rc = nodes[parent].rChild
        missNext[i] = rc if rc != i else missNext[parent]
    hitNext = missNext.copy()
    for i, node in enumerate(nodes):
        if node.lChild >= 0:
            hitNext[i] = node.lChild
        elif node.rChild >= 0:
            hitNext[i] = node.rChild
    return missNext, hitNext


def loadModel(filename: str, scale: float = 1, sah: bool = True, mixNorm: bool = True):
    """Load a mesh, build a BVH over its triangles and upload both to Taichi fields.

    Args:
        filename: mesh file, read with ``readModel``, which centres the
            vertices on their mean.
        scale: uniform factor applied to the vertex positions.
        sah: split internal nodes with the surface-area heuristic
            (``SAHSplit``) instead of at the median.
        mixNorm: if true, the hit normal is interpolated from the vertex
            normals with barycentric weights; otherwise the flat geometric
            normal of the triangle is used.

    Returns:
        A ``Shape`` subclass ``Model``. Its class-level Taichi fields hold the
        triangles, their vertex normals and the BVH. ``intersect(ray, closest)``
        returns ``(distance, normal)``.

    Raises:
        ValueError: if the mesh contains no triangles.
    """
    t0 = time()
    _vertex, _face, _norm = readModel(filename)
    _vertex *= scale
    objs = _vertex[_face].astype(np.float32)
    norms = _norm[_face].astype(np.float32)
    if len(objs) == 0:
        raise ValueError(f"{filename} contains no triangles")
    Logger.log(f"Loaded {filename} in {time() - t0:.3f} s")
    t0 = time()

    nodes: list[BVHNode] = []
    # Permutation of triangle ids; build() reorders it in place so that every
    # node covers the contiguous range order[begin:end].
    order = np.arange(len(objs))

    def build(begin: int, end: int) -> int:
        """Build the subtree for triangles ``order[begin:end]``; return its root index.

        Nodes are appended to ``nodes`` in post-order.
        NOTE: recursion depth equals tree depth, so a very unbalanced sequence
        of SAH splits on a huge mesh could reach Python's recursion limit.
        """
        if end - begin == 1:
            lc = rc = -1
            lb, ub = objs[order[begin]].min(axis=0), objs[order[begin]].max(axis=0)
        elif end - begin == 2:
            # Emit both leaves directly; their parent is appended right after
            # them, at index lc + 2.
            lc, rc = len(nodes), len(nodes) + 1
            lTri, rTri = objs[order[begin]], objs[order[begin + 1]]
            lc_lb, lc_ub = lTri.min(axis=0), lTri.max(axis=0)
            rc_lb, rc_ub = rTri.min(axis=0), rTri.max(axis=0)
            nodes.append(BVHNode(begin, lc_lb, lc_ub, -1, -1, lc + 2))
            nodes.append(BVHNode(begin + 1, rc_lb, rc_ub, -1, -1, lc + 2))
            lb = np.minimum(lc_lb, rc_lb)
            ub = np.maximum(lc_ub, rc_ub)
            begin = -1  # internal node: no triangle of its own
        else:
            _order = order[begin:end]
            _objs = objs[_order]
            # Sort triangles by centroid (vertex sum) along the axis of largest
            # extent. np.ptp is used because the ndarray.ptp method was
            # removed in NumPy 2.0.
            axis = np.ptp(_objs, axis=(0, 1)).argmax()
            perm = _objs[:, :, axis].sum(axis=1).argsort()
            order[begin:end] = _order[perm]
            if sah:
                cost = np.zeros(end - begin, dtype=np.float32)
                SAHSplit(_objs[perm], cost)
                # cost[k] puts the first k triangles on the left; k >= 1 and
                # mid <= end - 1 keep both children non-empty.
                mid = min(int(begin + max(1, cost.argmin())), end - 1)
            else:
                mid = (begin + end) // 2
            lc = build(begin, mid)
            rc = build(mid, end)
            # This node is about to be appended at index len(nodes).
            nodes[lc].parent = nodes[rc].parent = len(nodes)
            lb = np.minimum(nodes[lc].lBound, nodes[rc].lBound)
            ub = np.maximum(nodes[lc].uBound, nodes[rc].uBound)
            begin = -1
        nodes.append(BVHNode(begin, lb, ub, lc, rc, -1))
        return len(nodes) - 1

    build(0, len(objs))
    # Reorder triangles so that a leaf's BVHNode.index addresses them directly.
    objs = objs[order]
    norms = norms[order]

    class Model(Shape):
        # Number of BVH nodes; node n - 1 is the root.
        n = len(nodes)
        # Triangles in BVH order; row k of each matrix is vertex k.
        data = Mat3x3.field(shape=len(objs))
        # Vertex normals of each triangle, same layout as `data`.
        norm = Mat3x3.field(shape=len(norms))
        # Per node: triangle row in `data` for leaves, -1 for internal nodes.
        index = ti.field(dtype=int, shape=n)
        # Per node: next node after a box miss / hit (-1 ends traversal).
        missNext = ti.field(dtype=int, shape=n)
        hitNext = ti.field(dtype=int, shape=n)
        # Per node: AABB lower / upper corner.
        lBound = Vec3.field(shape=n)
        uBound = Vec3.field(shape=n)

        @ti.func
        def hitTriangle(self, curr: int, ray: Ray3, closest: float):
            # Moller-Trumbore test against the triangle of leaf `curr`; returns
            # (hit, t, u, v) where u, v weight vertices b and c respectively.
            mat = Model.data[Model.index[curr]]
            a, b, c = mat[0, :], mat[1, :], mat[2, :]
            T, E1, E2 = ray.origin - a, b - a, c - a
            P, Q = ray.direct.cross(E2), T.cross(E1)
            t, u, v = Vec3(Q.dot(E2), P.dot(T), Q.dot(ray.direct)) / P.dot(E1)
            is_hit = 1e-5 < t < closest and u > 0 and v > 0 and u + v <= 1
            return is_hit, t, u, v

        @ti.func
        def hitAABB(self, id: int, ray: Ray3, closest: float) -> bool:
            # Slab test: the entry/exit intervals of the three axes must overlap
            # within [0, closest].
            i1 = (Model.lBound[id] - ray.origin) / ray.direct
            i2 = (Model.uBound[id] - ray.origin) / ray.direct
            return max(min(i1, i2).max(), 0) <= min(max(i1, i2).min(), closest)

        @ti.func
        def intersect(self, ray: Ray3, closest: float):
            # Stack-free traversal from the root: follow hitNext after a box
            # hit and missNext after a miss, until -1.
            closest_ = closest
            curr = Model.n - 1
            hitId = -1
            u = v = 0.0

            while curr >= 0:
                if Model.hitAABB(self, curr, ray, closest_):
                    if Model.index[curr] >= 0:
                        hit, t, _u, _v = Model.hitTriangle(self, curr, ray, closest_)
                        if hit:
                            closest_ = t
                            hitId = curr
                            u, v = _u, _v
                    curr = Model.hitNext[curr]
                else:
                    curr = Model.missNext[curr]
            # NOTE(review): on a miss `normal` stays the zero vector, so
            # normalized() presumably yields NaN; callers should rely on the
            # returned distance to detect misses.
            normal = Vec3(0)
            if closest_ < closest:
                if ti.static(mixNorm):
                    n = Model.norm[Model.index[hitId]]
                    normal = n[0, :] * (1 - u - v) + n[1, :] * u + n[2, :] * v
                else:
                    mat = Model.data[Model.index[hitId]]
                    a, b, c = mat[0, :], mat[1, :], mat[2, :]
                    normal = (b - a).cross(c - a)
            return closest_, normal.normalized()

    missNext, hitNext = _threadBVH(nodes)

    Model.data.from_numpy(objs)
    Model.norm.from_numpy(norms)
    Model.missNext.from_numpy(missNext)
    Model.hitNext.from_numpy(hitNext)
    Model.index.from_numpy(np.array([node.index for node in nodes]))
    Model.lBound.from_numpy(np.array([node.lBound for node in nodes]))
    Model.uBound.from_numpy(np.array([node.uBound for node in nodes]))
    Logger.log(f'built {filename} in {time() - t0:.3f} s')
    return Model


def showModel(v: np.ndarray, f: np.ndarray, showWire: bool = False) -> None:
    """Preview a triangle mesh in a Taichi GGUI window.

    The mesh is centred on its per-axis vertex mean and scaled so its largest
    bounding-box extent is 1. Two point lights sit at opposite corners of that
    normalised bounding box. The call blocks until the window is closed.

    Args:
        v: vertex positions, one (x, y, z) row per vertex. Not modified; a
            float32 copy is used.
        f: vertex indices, three per triangle.
        showWire: also draw the mesh wireframe.
    """
    # Work on a float copy so the caller's array is not recentred/rescaled.
    v = np.array(v, dtype=np.float32)
    # Per-axis mean: v.mean() without an axis subtracts one scalar from every
    # coordinate and does not centre the mesh.
    v -= v.mean(axis=0)
    extent = (v.max(axis=0) - v.min(axis=0)).max()
    if extent > 0:  # a single-point mesh would otherwise divide by zero
        v /= extent
    # Lights at the corners of the *normalised* box so they scale with the model.
    light1, light2 = v.min(axis=0), v.max(axis=0)
    pos = ti.Vector.field(3, float, v.shape[0])
    pos.from_numpy(v)
    face = ti.field(int, f.shape[0] * 3)
    face.from_numpy(f.flatten())
    window = ti.ui.Window("presenting model", (1024, 1024))
    canvas = window.get_canvas()
    scene = ti.ui.Scene()
    camera = ti.ui.Camera()
    camera.position(0, 0, -3)
    while window.running:
        camera.track_user_inputs(window, movement_speed=0.03, hold_key=ti.ui.LMB)
        scene.set_camera(camera)

        scene.point_light(pos=light1, color=(1, 0, 1))
        scene.point_light(pos=light2, color=(1, 1, 0))
        scene.ambient_light((0.5, 0.5, 0.5))

        scene.particles(pos, radius=0.001, color=(0, 1, 1))
        scene.mesh(pos, indices=face, show_wireframe=showWire)

        canvas.scene(scene)
        window.show()
